import numpy as np
#from matplotlib.pyplot import plot, savefig, title, legend, ylim, cla, xlabel, ylabel, annotate
from time import clock
#from sklearn.utils.extmath import fast_dot
from mpl_toolkits.mplot3d import axes3d
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
#@profile
def Method_Iterative(A, b, Precision, Max_Iteration=None):
	'''
	Solve A*x = b with the fixed-point iteration x <- T*x + c.

	T (Iteration_Matrix), c (Constant_Matrix) and the starting vector come
	from Generate_Matrix_for_Iterate(A, b). The iteration stops once the
	largest absolute change between two successive iterates is <= Precision.

	Parameters:
		A, b          -- system matrix and right-hand side, passed straight
		                 to Generate_Matrix_for_Iterate.
		Precision     -- tolerance on max|x_next - x|.
		Max_Iteration -- optional cap on the number of iterations. None (the
		                 default) keeps iterating until the tolerance is met.
		                 The iteration is not guaranteed to converge for every
		                 A, so a cap avoids looping forever.

	Returns:
		(X, Error, count) -- last iterate, last max-abs change and number of
		iterations performed. If Max_Iteration stopped the loop, Error is
		still > Precision, which the caller can check.
	'''
	Iteration_Matrix, Constant_Matrix, X = Generate_Matrix_for_Iterate(A, b)
	# Start above the tolerance so the first iteration is always attempted.
	Error = Precision + 1
	count = 0
	# NaN > Precision is False, so an iteration whose error becomes NaN
	# (e.g. after overflowing to inf) also leaves the loop.
	while Error > Precision:
		if Max_Iteration is not None and count >= Max_Iteration:
			break
		# '@' is matrix multiplication for np.matrix and plain ndarrays alike
		# ('*' would be element-wise for ndarrays).
		X_next = Iteration_Matrix @ X + Constant_Matrix
		Error = np.max(np.abs(X_next - X))
		X = X_next
		count += 1
	return X, Error, count

def Generate_Matrix_for_Iterate(A, b):
	'''
	Build the Jacobi splitting of A used by Method_Iterative.

	A = M - N with M = diag(A) and N = M - A, hence
	x = M^{-1}*N*x + M^{-1}*b
	x = Iteration_Matrix*x + Constant_Matrix

	When every diagonal entry of A is equal, as for the matrix built by
	Generate_Matrix_A_b, this gives the same values as scaling by 1/A[0,0].
	Otherwise it is the proper Jacobi splitting.

	Parameters:
		A -- square matrix (np.matrix or 2-D array) with a non-zero diagonal.
		b -- right-hand side with A.shape[0] entries (column, row or 1-D).

	Returns:
		Iteration_Matrix (n x n np.matrix), Constant_Matrix (n x 1 np.matrix),
		Initial_Vector (n x 1 np.matrix of zeros).

	Raises:
		ZeroDivisionError if a diagonal entry of A is zero.
	'''
	A = np.asarray(A, dtype=float)
	Dimention = A.shape[0]
	Diagonal = A.diagonal()
	if np.any(Diagonal == 0):
		raise ZeroDivisionError("Jacobi splitting needs a non-zero diagonal")
	# M^{-1} is diagonal, so applying it is a row-wise scaling by 1/a_ii.
	M_inv = (1.0 / Diagonal)[:, np.newaxis]
	N = np.diag(Diagonal) - A
	Iteration_Matrix = np.matrix(M_inv * N)
	# Force b into an n x 1 column so that Iteration_Matrix*x + Constant_Matrix
	# stays n x 1 instead of broadcasting to n x n.
	Constant_Matrix = np.matrix(M_inv * np.asarray(b, dtype=float).reshape(Dimention, 1))
	Initial_Vector = np.matrix(np.zeros([Dimention, 1]))
	return Iteration_Matrix, Constant_Matrix, Initial_Vector


def Generate_Matrix_A_b(Nx, Ny, Xmax, Xmin, Ymax, Ymin):
	"""
	Assemble A x = b for the 5-point finite-difference Laplace problem (zero
	source term) on [Xmin, Xmax] x [Ymin, Ymax] with the border values of
	Border_Generator.

	Nx and Ny are the numbers of interior grid points in the X and Y
	directions, so each direction is divided into Nx+1 (resp. Ny+1) intervals.
	Unknown Ny*i + j is the interior point with x index i and y index j.
	Each row, scaled by dx**2*dy**2, reads
		dy**2*(u[i-1,j] + u[i+1,j]) + dx**2*(u[i,j-1] + u[i,j+1])
			- 2*(dx**2 + dy**2)*u[i,j] = 0,
	with neighbours that lie on the border moved to the right-hand side b.

	Returns A ((Nx*Ny) x (Nx*Ny) np.matrix) and b ((Nx*Ny) x 1 np.matrix).
	Raises ValueError if Nx or Ny is smaller than 1.
	"""
	if Nx < 1 or Ny < 1:
		raise ValueError("Nx and Ny must be at least 1")
	dx = (Xmax - Xmin)/(Nx+1)
	dy = (Ymax - Ymin)/(Ny+1)
	D, H, Z = Generate_Diagonal_Matrix(Nx, Ny, dx, dy)
	# Block tridiagonal layout: D on the block diagonal (coupling along y),
	# H on the first block off-diagonals (coupling along x), Z elsewhere.
	Blocks = [[D if Row == Col else H if abs(Row - Col) == 1 else Z
			for Col in range(Nx)] for Row in range(Nx)]
	A = np.bmat(Blocks)

	# B[i, j] is the right-hand side of unknown Ny*i + j. The border arrays
	# have N+2 entries; indices 1..N skip the two corner values.
	B = np.zeros([Nx, Ny])
	Border_Up, Border_Down, Border_Left, Border_Right = Border_Generator(Nx, Ny)
	B[0, :] += (-dy**2)*Border_Left[1:Ny+1]
	B[Nx-1, :] += (-dy**2)*Border_Right[1:Ny+1]
	B[:, 0] += (-dx**2)*Border_Up[1:Nx+1]
	B[:, Ny-1] += (-dx**2)*Border_Down[1:Nx+1]
	b = np.asmatrix(B.reshape(Nx*Ny)).T
	return A, b

# A = np.matrix("-4,1,0,1,0,0,0,0,0;1,-4,1,0,1,0,0,0,0;0,1,-4,0,0,1,0,0,0;1,0,0,-4,1,0,1,0,0;0,1,0,1,-4,1,0,1,0;0,0,1,0,1,-4,0,0,1;0,0,0,1,0,0,-4,1,0;0,0,0,0,1,0,1,-4,1;0,0,0,0,0,1,0,1,-4")
# b=np.matrix("5;6;7;8;9;7;8;9;2")
def Generate_Diagonal_Matrix(Nx, Ny, dx, dy):
	"""
	Build the Ny x Ny blocks of the finite-difference matrix.

	Returns (D, H, Z), all np.matrix of shape Ny x Ny:
		D -- tridiagonal, -2*(dx**2 + dy**2) on the diagonal and dx**2 on the
		     first sub- and super-diagonal (coupling along y inside a block);
		H -- dy**2 * identity (coupling to the neighbouring block along x);
		Z -- zeros.
	Nx is not used; it is kept for interface compatibility.
	Raises ValueError if Ny is smaller than 1.
	"""
	if Ny < 1:
		raise ValueError("Ny must be at least 1")
	Center = (-2)*(dx**2 + dy**2)
	Side = dx**2
	# np.diag(v, k) places v on the k-th diagonal; this works for every Ny >= 1
	# (for Ny == 1 the off-diagonals are empty and D is just [[Center]]).
	D = np.matrix(np.diag(np.full(Ny, Center, dtype=float))
		+ np.diag(np.full(Ny - 1, Side, dtype=float), 1)
		+ np.diag(np.full(Ny - 1, Side, dtype=float), -1))
	H = np.matrix(np.eye(Ny))*(dy**2)
	Z = np.matrix(np.zeros([Ny, Ny]))
	return D, H, Z


def Border_Generator(Nx, Ny, Border_Type="array"):
	"""
	Return the boundary values (Up, Down, Left, Right) of the grid.

	Up/Down have Nx+2 entries (one per x grid line, corners included) and
	Left/Right have Ny+2 entries (one per y grid line, corners included).
	Current values: Up = 0, Down = 5, Left and Right rise linearly from 0 to 5.

	Border_Type:
		"array"  -- the four 1-D arrays described above.
		"matrix" -- Up/Down without their two corner entries as 1 x Nx row
		            matrices, and Left/Right (corners included) as
		            (Ny+2) x 1 column matrices.

	Raises ValueError for any other Border_Type.
	"""
	Border_Up = np.linspace(0,0,Nx+2)
	Border_Down = np.linspace(5,5,Nx+2)
	Border_Left = np.linspace(0,5,Ny+2)#+np.cos(np.linspace(0,5,Ny+2))
	Border_Right = np.linspace(0,5,Ny+2)#+np.sin(np.linspace(0,5,Ny+2))
	if Border_Type == "array":
		return Border_Up, Border_Down, Border_Left, Border_Right
	if Border_Type == "matrix":
		# np.asmatrix replaces the np.mat alias, which NumPy 2.0 removed.
		return (np.asmatrix(Border_Up[1:-1]), np.asmatrix(Border_Down[1:-1]),
			np.asmatrix(Border_Left).T, np.asmatrix(Border_Right).T)
	raise ValueError("Border_Type must be 'array' or 'matrix', got %r" % (Border_Type,))

def Plot_Result(Z, Nx, Ny, Xmax, Xmin, Ymax, Ymin):
	"""
	Show the solution, borders included, as a 3-D surface with contour
	projections on the z = min(Z), x = Xmin-1 and y = Ymax+1 walls.

	Z is the Nx*Ny solution vector from the solver, ordered like
	Generate_Matrix_A_b: entry Ny*i + j is the interior point with x index i
	and y index j. Axis limits follow the domain and the data range; for the
	[0,5]x[0,5] domain with the current borders they are x: -1..5.5,
	y: 0..6 and z: 0..6.
	"""
	Border_Up, Border_Down, Border_Left, Border_Right = Border_Generator(Nx, Ny, Border_Type="matrix")
	xs = np.linspace(Xmin, Xmax, Nx+2)
	ys = np.linspace(Ymin, Ymax, Ny+2)
	X, Y = np.meshgrid(xs, ys)
	# A Fortran-order reshape puts entry Ny*i + j at [j, i]: rows follow y and
	# columns follow x, matching the meshgrid layout.
	Z = np.reshape(Z, [Ny, Nx], "F")
	# Stack the Up/Down rows (first/last y line), then the Left/Right columns
	# (first/last x line, corners included) around the interior values.
	Z = np.bmat([[Border_Left, np.bmat([[Border_Up], [Z], [Border_Down]]), Border_Right]])
	Z = np.asarray(Z)
	Z_low, Z_high = Z.min(), Z.max()

	fig = plt.figure()
	# Figure.gca() no longer accepts projection= (Matplotlib >= 3.6).
	ax = fig.add_subplot(projection='3d')

	ax.contourf(X, Y, Z, zdir='z', offset=Z_low, cmap=cm.coolwarm)
	ax.contourf(X, Y, Z, zdir='x', offset=Xmin - 1, cmap=cm.coolwarm)
	ax.contour(X, Y, Z, zdir='y', offset=Ymax + 1, cmap=cm.coolwarm)
	surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, alpha=0.5, cmap=cm.coolwarm, linewidth=0.01)
	fig.colorbar(surf, shrink=0.5, aspect=5)
	ax.set_xlabel('X')
	# Leave room for the x-wall projection at Xmin-1 and the y-wall at Ymax+1.
	ax.set_xlim(Xmin - 1, Xmax + 0.5)
	ax.set_ylabel('Y')
	ax.set_ylim(Ymin, Ymax + 1)
	ax.set_zlabel('Z')
	ax.set_zlim(Z_low, Z_high + 1)
	plt.show()

if __name__ == "__main__":
	# NOTE(review): time.clock, imported at the top of this file, was removed
	# in Python 3.8; perf_counter is used for timing here instead.
	from time import perf_counter
	# Benchmark the iterative solver on N x N interior grids over [0,5]x[0,5]
	# and append one "N seconds" line per size to the file N_and_Time.
	for N in range(4, 102, 2):
		A, b = Generate_Matrix_A_b(N, N, 5, 0, 5, 0)
		print(N)
		start_time = perf_counter()
		Z, Error, count = Method_Iterative(A, b, 1e-5)
		end_time = perf_counter()
		# Opened per size so each result is on disk as soon as it is measured.
		with open("N_and_Time", "a") as benchmark_file:
			print(N, end_time - start_time, file=benchmark_file)

	# To visualise the last solution:
	# Plot_Result(Z, N, N, Xmax=5, Xmin=0, Ymax=5, Ymin=0)
